In [ ]:
from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Set random seeds for reproducibility. Seed every RNG the notebook touches
# (python random, numpy, torch) so Restart & Run All reproduces results.
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
np.random.seed(manualSeed)  # numpy is imported above and used below; seed it too
torch.manual_seed(manualSeed)
In [ ]:
# Root directory for dataset; consumed by dset.ImageFolder in the next cell,
# so it must contain at least one subdirectory of images
dataroot = "data/celeba"

# Number of worker processes for the dataloader
workers = 2

# Batch size during training
batch_size = 128

# Spatial size of training images. All images will be resized to this
#   size using a transformer.
image_size = 64

# Number of channels in the training images. For color images this is 3
nc = 3

# Size of z latent vector (i.e. size of generator input)
nz = 100

# Size of feature maps in generator
ngf = 64

# Size of feature maps in discriminator
ndf = 64

# Number of training epochs
num_epochs = 5

# Learning rate for optimizers
lr = 0.0002

# Beta1 hyperparam for Adam optimizers
beta1 = 0.5

# Number of GPUs available. Use 0 for CPU mode.
# NOTE(review): a later cell pins the device to a specific CUDA index while
# DataParallel uses list(range(ngpu)) — keep the two consistent.
ngpu = 1
In [ ]:
# Build the dataset. ImageFolder expects dataroot to contain image subfolders.
# Each image is resized, center-cropped to image_size, and normalized with
# mean 0.5 / std 0.5 per channel, i.e. mapped into [-1, 1] to match the
# generator's Tanh output range.
dataset = dset.ImageFolder(root=dataroot,
                          transform=transforms.Compose([
                              transforms.Resize(image_size),
                              transforms.CenterCrop(image_size),
                              transforms.ToTensor(),
                              transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                          ]))

# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

# Decide which device we want to run on. Use cuda:0 (not a hard-coded index)
# so it stays consistent with the DataParallel device_ids list(range(ngpu))
# used when the networks are created below.
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Plot some training images
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
plt.show()
In [ ]:
# custom weights initialization called on netG and netD
def weights_init(m):
    """Initialize DCGAN layer weights in place.

    Conv / ConvTranspose weights ~ N(0, 0.02); BatchNorm weights ~ N(1, 0.02)
    with zero bias. Other module types are left untouched. Intended to be
    applied recursively via ``net.apply(weights_init)``.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        # nn.init ops run under no_grad internally, so there is no need to
        # reach through the deprecated .data attribute.
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.constant_(m.bias, 0)
In [ ]:
# Generator Code
class Generator(nn.Module):
    """DCGAN generator: maps a latent vector z of shape (N, nz, 1, 1) to an
    image of shape (N, nc, 64, 64) in [-1, 1] via strided transposed
    convolutions with BatchNorm + ReLU, finishing with Tanh.

    The sizes were previously read from module-level globals; they are now
    parameters whose defaults match the notebook config, so existing
    ``Generator(ngpu)`` calls behave identically.

    Args:
        ngpu: number of GPUs (stored; the multi-GPU wrapping happens outside).
        nz:  length of the latent vector (generator input).
        ngf: base number of generator feature maps.
        nc:  number of output image channels.
    """

    def __init__(self, ngpu, nz=100, ngf=64, nc=3):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, a (nz x 1 x 1) latent vector
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)
In [ ]:
# Instantiate the generator and move it to the training device
netG = Generator(ngpu).to(device)

# Wrap in DataParallel when multiple GPUs are requested
multi_gpu = (device.type == 'cuda') and (ngpu > 1)
if multi_gpu:
    netG = nn.DataParallel(netG, device_ids=list(range(ngpu)))

# Randomly initialize every Conv/BatchNorm weight (mean 0, stdev 0.02)
netG.apply(weights_init)

# Show the resulting architecture
print(netG)
In [ ]:
class Discriminator(nn.Module):
    """DCGAN discriminator: maps an (N, nc, 64, 64) image to an (N, 1, 1, 1)
    real-vs-fake probability via strided convolutions with BatchNorm +
    LeakyReLU(0.2), finishing with Sigmoid.

    The sizes were previously read from module-level globals; they are now
    parameters whose defaults match the notebook config, so existing
    ``Discriminator(ngpu)`` calls behave identically.

    Args:
        ngpu: number of GPUs (stored; the multi-GPU wrapping happens outside).
        nc:  number of input image channels.
        ndf: base number of discriminator feature maps.
    """

    def __init__(self, ngpu, nc=3, ndf=64):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)
In [ ]:
# Instantiate the discriminator and move it to the training device
netD = Discriminator(ngpu).to(device)

# Wrap in DataParallel when multiple GPUs are requested
multi_gpu = (device.type == 'cuda') and (ngpu > 1)
if multi_gpu:
    netD = nn.DataParallel(netD, device_ids=list(range(ngpu)))

# Randomly initialize every Conv/BatchNorm weight (mean 0, stdev 0.02)
netD.apply(weights_init)

# Show the resulting architecture
print(netD)
In [ ]:
# Binary cross-entropy loss drives the real/fake classification game
criterion = nn.BCELoss()

# A fixed batch of latent vectors, reused throughout training so the
# generator's progress can be visualized on identical inputs
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Label conventions used during training
real_label = 1.
fake_label = 0.

# One Adam optimizer per network, sharing the same hyperparameters
adam_betas = (beta1, 0.999)
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=adam_betas)
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=adam_betas)
In [16]:
# Sanity-check the batch shapes produced by the dataloader
for batch_index, batch in enumerate(dataloader, 0):
    images, labels = batch
    print(images.shape)  # torch.Size([128, 3, 64, 64])
    print(labels.shape)  # torch.Size([128])
    break
torch.Size([128, 3, 64, 64])
torch.Size([128])
torch.Size([128, 3, 64, 64])
torch.Size([128])
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-16-7a3c56b971fb> in <module>
      1 # 测试batchsize
      2 # For each epoch
----> 3 for i, data in enumerate(dataloader, 0):
      4     print(data[0].shape) # torch.Size([128, 3, 64, 64])
      5     print(data[1].shape)

~/anaconda3/envs/tf/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
    361 
    362     def __next__(self):
--> 363         data = self._next_data()
    364         self._num_yielded += 1
    365         if self._dataset_kind == _DatasetKind.Iterable and \

~/anaconda3/envs/tf/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_data(self)
    972 
    973             assert not self._shutdown and self._tasks_outstanding > 0
--> 974             idx, data = self._get_data()
    975             self._tasks_outstanding -= 1
    976 

~/anaconda3/envs/tf/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _get_data(self)
    939         else:
    940             while True:
--> 941                 success, data = self._try_get_data()
    942                 if success:
    943                     return data

~/anaconda3/envs/tf/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _try_get_data(self, timeout)
    777         #   (bool: whether successfully get data, any: data if successful else None)
    778         try:
--> 779             data = self._data_queue.get(timeout=timeout)
    780             return (True, data)
    781         except Exception as e:

~/anaconda3/envs/tf/lib/python3.7/multiprocessing/queues.py in get(self, block, timeout)
    102                 if block:
    103                     timeout = deadline - time.monotonic()
--> 104                     if not self._poll(timeout):
    105                         raise Empty
    106                 elif not self._poll():

~/anaconda3/envs/tf/lib/python3.7/multiprocessing/connection.py in poll(self, timeout)
    255         self._check_closed()
    256         self._check_readable()
--> 257         return self._poll(timeout)
    258 
    259     def __enter__(self):

~/anaconda3/envs/tf/lib/python3.7/multiprocessing/connection.py in _poll(self, timeout)
    412 
    413     def _poll(self, timeout):
--> 414         r = wait([self], timeout)
    415         return bool(r)
    416 

~/anaconda3/envs/tf/lib/python3.7/multiprocessing/connection.py in wait(object_list, timeout)
    919 
    920             while True:
--> 921                 ready = selector.select(timeout)
    922                 if ready:
    923                     return [key.fileobj for (key, events) in ready]

~/anaconda3/envs/tf/lib/python3.7/selectors.py in select(self, timeout)
    413         ready = []
    414         try:
--> 415             fd_event_list = self._selector.poll(timeout)
    416         except InterruptedError:
    417             return ready

KeyboardInterrupt: 
In [17]:
# Training Loop
#
# Each iteration does one discriminator step (real batch + fake batch, with
# accumulated gradients) followed by one generator step.

# Lists to keep track of progress
img_list = []   # image grids generated from fixed_noise, for the animation below
G_losses = []   # per-iteration generator loss
D_losses = []   # per-iteration discriminator loss
iters = 0       # global iteration counter across all epochs

print("Starting Training Loop...")

# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D; detach() keeps this backward pass
        # from flowing gradients into G during the discriminator update
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch; they accumulate on top of
        # the gradients from the all-real backward pass above
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)    # G's loss is small when D scores its fakes near 1 (i.e. classifies them as real)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1
        
Starting Training Loop...
[0/5][0/1583]	Loss_D: 1.5029	Loss_G: 6.4423	D(x): 0.6653	D(G(z)): 0.5850 / 0.0029
[0/5][50/1583]	Loss_D: 0.4859	Loss_G: 31.8880	D(x): 0.8118	D(G(z)): 0.0000 / 0.0000
[0/5][100/1583]	Loss_D: 0.8345	Loss_G: 14.9826	D(x): 0.7317	D(G(z)): 0.0001 / 0.0000
[0/5][150/1583]	Loss_D: 0.7706	Loss_G: 5.4995	D(x): 0.6385	D(G(z)): 0.0354 / 0.0079
[0/5][200/1583]	Loss_D: 0.6316	Loss_G: 3.1242	D(x): 0.6750	D(G(z)): 0.0621 / 0.0627
[0/5][250/1583]	Loss_D: 0.8689	Loss_G: 6.7555	D(x): 0.8780	D(G(z)): 0.4492 / 0.0021
[0/5][300/1583]	Loss_D: 0.6864	Loss_G: 5.3952	D(x): 0.8877	D(G(z)): 0.3701 / 0.0089
[0/5][350/1583]	Loss_D: 1.1493	Loss_G: 5.9940	D(x): 0.4636	D(G(z)): 0.0047 / 0.0066
[0/5][400/1583]	Loss_D: 0.5025	Loss_G: 3.6571	D(x): 0.8990	D(G(z)): 0.2767 / 0.0464
[0/5][450/1583]	Loss_D: 0.5523	Loss_G: 6.7327	D(x): 0.9030	D(G(z)): 0.2751 / 0.0039
[0/5][500/1583]	Loss_D: 0.4248	Loss_G: 4.2081	D(x): 0.8897	D(G(z)): 0.1936 / 0.0244
[0/5][550/1583]	Loss_D: 0.5262	Loss_G: 5.0320	D(x): 0.9442	D(G(z)): 0.3148 / 0.0134
[0/5][600/1583]	Loss_D: 0.9008	Loss_G: 2.9109	D(x): 0.5280	D(G(z)): 0.0688 / 0.1134
[0/5][650/1583]	Loss_D: 0.4281	Loss_G: 3.6998	D(x): 0.7866	D(G(z)): 0.0843 / 0.0405
[0/5][700/1583]	Loss_D: 0.5641	Loss_G: 3.5155	D(x): 0.6875	D(G(z)): 0.0619 / 0.0573
[0/5][750/1583]	Loss_D: 0.6768	Loss_G: 5.3575	D(x): 0.7802	D(G(z)): 0.2699 / 0.0146
[0/5][800/1583]	Loss_D: 0.3818	Loss_G: 3.8197	D(x): 0.8637	D(G(z)): 0.1678 / 0.0347
[0/5][850/1583]	Loss_D: 0.9934	Loss_G: 6.4453	D(x): 0.9447	D(G(z)): 0.5080 / 0.0041
[0/5][900/1583]	Loss_D: 0.5654	Loss_G: 3.0134	D(x): 0.7214	D(G(z)): 0.1207 / 0.0775
[0/5][950/1583]	Loss_D: 0.5743	Loss_G: 5.5115	D(x): 0.8583	D(G(z)): 0.2919 / 0.0081
[0/5][1000/1583]	Loss_D: 0.5675	Loss_G: 4.8195	D(x): 0.8524	D(G(z)): 0.2725 / 0.0157
[0/5][1050/1583]	Loss_D: 0.6168	Loss_G: 6.1604	D(x): 0.9281	D(G(z)): 0.3655 / 0.0049
[0/5][1100/1583]	Loss_D: 0.3554	Loss_G: 4.3391	D(x): 0.7716	D(G(z)): 0.0489 / 0.0229
[0/5][1150/1583]	Loss_D: 0.5849	Loss_G: 3.0146	D(x): 0.6687	D(G(z)): 0.0344 / 0.0845
[0/5][1200/1583]	Loss_D: 0.4374	Loss_G: 3.2759	D(x): 0.7924	D(G(z)): 0.1280 / 0.0539
[0/5][1250/1583]	Loss_D: 0.4374	Loss_G: 3.6335	D(x): 0.7671	D(G(z)): 0.0924 / 0.0463
[0/5][1300/1583]	Loss_D: 1.5776	Loss_G: 6.7401	D(x): 0.9813	D(G(z)): 0.7012 / 0.0029
[0/5][1350/1583]	Loss_D: 0.4154	Loss_G: 3.0336	D(x): 0.7384	D(G(z)): 0.0472 / 0.0754
[0/5][1400/1583]	Loss_D: 0.5394	Loss_G: 2.1617	D(x): 0.7231	D(G(z)): 0.1209 / 0.1537
[0/5][1450/1583]	Loss_D: 0.4363	Loss_G: 4.1643	D(x): 0.8633	D(G(z)): 0.2162 / 0.0223
[0/5][1500/1583]	Loss_D: 0.4096	Loss_G: 3.2701	D(x): 0.7547	D(G(z)): 0.0658 / 0.0688
[0/5][1550/1583]	Loss_D: 0.7451	Loss_G: 3.1521	D(x): 0.7357	D(G(z)): 0.2708 / 0.0679
[1/5][0/1583]	Loss_D: 0.4592	Loss_G: 4.4285	D(x): 0.8657	D(G(z)): 0.2294 / 0.0233
[1/5][50/1583]	Loss_D: 0.4396	Loss_G: 3.7673	D(x): 0.8459	D(G(z)): 0.2033 / 0.0354
[1/5][100/1583]	Loss_D: 0.5022	Loss_G: 4.2431	D(x): 0.7792	D(G(z)): 0.1530 / 0.0255
[1/5][150/1583]	Loss_D: 0.4253	Loss_G: 4.1007	D(x): 0.8878	D(G(z)): 0.2333 / 0.0252
[1/5][200/1583]	Loss_D: 0.4477	Loss_G: 3.7672	D(x): 0.8657	D(G(z)): 0.2237 / 0.0380
[1/5][250/1583]	Loss_D: 0.5257	Loss_G: 2.6835	D(x): 0.7320	D(G(z)): 0.1356 / 0.1019
[1/5][300/1583]	Loss_D: 0.4746	Loss_G: 3.6705	D(x): 0.8216	D(G(z)): 0.1966 / 0.0408
[1/5][350/1583]	Loss_D: 0.3632	Loss_G: 3.1717	D(x): 0.8682	D(G(z)): 0.1741 / 0.0634
[1/5][400/1583]	Loss_D: 0.6198	Loss_G: 1.6818	D(x): 0.6312	D(G(z)): 0.0636 / 0.2397
[1/5][450/1583]	Loss_D: 0.8857	Loss_G: 2.0437	D(x): 0.5457	D(G(z)): 0.0155 / 0.2036
[1/5][500/1583]	Loss_D: 0.4877	Loss_G: 2.5646	D(x): 0.7139	D(G(z)): 0.0798 / 0.1166
[1/5][550/1583]	Loss_D: 0.7475	Loss_G: 4.1037	D(x): 0.8232	D(G(z)): 0.3498 / 0.0286
[1/5][600/1583]	Loss_D: 0.4717	Loss_G: 1.7062	D(x): 0.7320	D(G(z)): 0.0921 / 0.2273
[1/5][650/1583]	Loss_D: 0.3380	Loss_G: 3.5667	D(x): 0.8362	D(G(z)): 0.1148 / 0.0470
[1/5][700/1583]	Loss_D: 0.6683	Loss_G: 4.0823	D(x): 0.9044	D(G(z)): 0.3873 / 0.0303
[1/5][750/1583]	Loss_D: 1.0230	Loss_G: 6.5927	D(x): 0.9695	D(G(z)): 0.5670 / 0.0027
[1/5][800/1583]	Loss_D: 0.4032	Loss_G: 3.4219	D(x): 0.8592	D(G(z)): 0.1873 / 0.0525
[1/5][850/1583]	Loss_D: 0.7501	Loss_G: 4.3012	D(x): 0.9041	D(G(z)): 0.4211 / 0.0225
[1/5][900/1583]	Loss_D: 0.5186	Loss_G: 3.6637	D(x): 0.8887	D(G(z)): 0.2839 / 0.0411
[1/5][950/1583]	Loss_D: 1.8251	Loss_G: 7.9901	D(x): 0.9845	D(G(z)): 0.7751 / 0.0008
[1/5][1000/1583]	Loss_D: 0.4369	Loss_G: 3.6613	D(x): 0.8947	D(G(z)): 0.2502 / 0.0382
[1/5][1050/1583]	Loss_D: 0.6744	Loss_G: 1.5955	D(x): 0.6012	D(G(z)): 0.0638 / 0.2842
[1/5][1100/1583]	Loss_D: 0.4600	Loss_G: 3.2876	D(x): 0.8318	D(G(z)): 0.2033 / 0.0495
[1/5][1150/1583]	Loss_D: 0.6046	Loss_G: 4.1184	D(x): 0.8902	D(G(z)): 0.3394 / 0.0221
[1/5][1200/1583]	Loss_D: 0.9598	Loss_G: 1.5054	D(x): 0.4873	D(G(z)): 0.0296 / 0.2899
[1/5][1250/1583]	Loss_D: 0.5757	Loss_G: 2.1358	D(x): 0.7770	D(G(z)): 0.2038 / 0.1552
[1/5][1300/1583]	Loss_D: 0.5434	Loss_G: 2.2407	D(x): 0.6909	D(G(z)): 0.1079 / 0.1444
[1/5][1350/1583]	Loss_D: 0.5309	Loss_G: 3.5128	D(x): 0.8650	D(G(z)): 0.2870 / 0.0403
[1/5][1400/1583]	Loss_D: 0.8076	Loss_G: 4.1176	D(x): 0.9058	D(G(z)): 0.4594 / 0.0247
[1/5][1450/1583]	Loss_D: 0.6124	Loss_G: 3.9135	D(x): 0.9096	D(G(z)): 0.3657 / 0.0279
[1/5][1500/1583]	Loss_D: 0.6686	Loss_G: 2.6917	D(x): 0.8309	D(G(z)): 0.3299 / 0.0939
[1/5][1550/1583]	Loss_D: 1.6094	Loss_G: 1.0694	D(x): 0.2752	D(G(z)): 0.0170 / 0.4007
[2/5][0/1583]	Loss_D: 0.4884	Loss_G: 2.5874	D(x): 0.7268	D(G(z)): 0.1054 / 0.1022
[2/5][50/1583]	Loss_D: 1.4683	Loss_G: 1.5614	D(x): 0.3111	D(G(z)): 0.0601 / 0.2895
[2/5][100/1583]	Loss_D: 0.5546	Loss_G: 3.3662	D(x): 0.8921	D(G(z)): 0.3092 / 0.0509
[2/5][150/1583]	Loss_D: 0.6266	Loss_G: 1.4920	D(x): 0.6775	D(G(z)): 0.1456 / 0.2835
[2/5][200/1583]	Loss_D: 0.7972	Loss_G: 4.4892	D(x): 0.9052	D(G(z)): 0.4644 / 0.0166
[2/5][250/1583]	Loss_D: 0.6821	Loss_G: 1.6232	D(x): 0.6864	D(G(z)): 0.2115 / 0.2316
[2/5][300/1583]	Loss_D: 0.7365	Loss_G: 1.5125	D(x): 0.5543	D(G(z)): 0.0641 / 0.2663
[2/5][350/1583]	Loss_D: 0.5135	Loss_G: 3.4470	D(x): 0.8726	D(G(z)): 0.2830 / 0.0467
[2/5][400/1583]	Loss_D: 0.8392	Loss_G: 2.0930	D(x): 0.6734	D(G(z)): 0.2834 / 0.1530
[2/5][450/1583]	Loss_D: 0.6207	Loss_G: 2.2681	D(x): 0.7069	D(G(z)): 0.1999 / 0.1305
[2/5][500/1583]	Loss_D: 0.4368	Loss_G: 2.1653	D(x): 0.7492	D(G(z)): 0.1098 / 0.1412
[2/5][550/1583]	Loss_D: 0.8889	Loss_G: 1.8015	D(x): 0.5062	D(G(z)): 0.0735 / 0.2132
[2/5][600/1583]	Loss_D: 0.3806	Loss_G: 2.8559	D(x): 0.8594	D(G(z)): 0.1852 / 0.0735
[2/5][650/1583]	Loss_D: 1.9121	Loss_G: 2.4695	D(x): 0.7946	D(G(z)): 0.7261 / 0.1318
[2/5][700/1583]	Loss_D: 0.8709	Loss_G: 3.8589	D(x): 0.8184	D(G(z)): 0.4433 / 0.0271
[2/5][750/1583]	Loss_D: 0.5468	Loss_G: 2.0756	D(x): 0.7850	D(G(z)): 0.2318 / 0.1521
[2/5][800/1583]	Loss_D: 1.1574	Loss_G: 0.1841	D(x): 0.4000	D(G(z)): 0.0392 / 0.8498
[2/5][850/1583]	Loss_D: 1.1876	Loss_G: 1.3117	D(x): 0.3737	D(G(z)): 0.0406 / 0.3352
[2/5][900/1583]	Loss_D: 0.5884	Loss_G: 1.5073	D(x): 0.6618	D(G(z)): 0.1186 / 0.2570
[2/5][950/1583]	Loss_D: 0.7464	Loss_G: 1.6026	D(x): 0.5838	D(G(z)): 0.1192 / 0.2465
[2/5][1000/1583]	Loss_D: 0.8106	Loss_G: 1.5935	D(x): 0.5482	D(G(z)): 0.1138 / 0.2576
[2/5][1050/1583]	Loss_D: 1.0301	Loss_G: 0.9337	D(x): 0.4632	D(G(z)): 0.1007 / 0.4397
[2/5][1100/1583]	Loss_D: 0.9579	Loss_G: 3.9689	D(x): 0.8162	D(G(z)): 0.4746 / 0.0296
[2/5][1150/1583]	Loss_D: 0.5497	Loss_G: 2.1319	D(x): 0.7154	D(G(z)): 0.1462 / 0.1508
[2/5][1200/1583]	Loss_D: 1.0015	Loss_G: 3.3431	D(x): 0.8876	D(G(z)): 0.5276 / 0.0481
[2/5][1250/1583]	Loss_D: 0.5956	Loss_G: 1.7995	D(x): 0.6556	D(G(z)): 0.1006 / 0.2051
[2/5][1300/1583]	Loss_D: 0.4923	Loss_G: 1.8547	D(x): 0.7482	D(G(z)): 0.1470 / 0.1892
[2/5][1350/1583]	Loss_D: 0.4625	Loss_G: 2.3031	D(x): 0.7836	D(G(z)): 0.1714 / 0.1290
[2/5][1400/1583]	Loss_D: 0.5085	Loss_G: 2.9163	D(x): 0.8776	D(G(z)): 0.2913 / 0.0668
[2/5][1450/1583]	Loss_D: 0.8956	Loss_G: 0.7772	D(x): 0.5154	D(G(z)): 0.1233 / 0.4948
[2/5][1500/1583]	Loss_D: 0.7101	Loss_G: 1.8090	D(x): 0.6657	D(G(z)): 0.2197 / 0.1924
[2/5][1550/1583]	Loss_D: 0.9502	Loss_G: 0.7733	D(x): 0.4884	D(G(z)): 0.0971 / 0.5111
[3/5][0/1583]	Loss_D: 0.6662	Loss_G: 2.4159	D(x): 0.7296	D(G(z)): 0.2653 / 0.1115
[3/5][50/1583]	Loss_D: 0.7752	Loss_G: 1.5012	D(x): 0.5814	D(G(z)): 0.1322 / 0.2680
[3/5][100/1583]	Loss_D: 0.7780	Loss_G: 3.6264	D(x): 0.8690	D(G(z)): 0.4333 / 0.0348
[3/5][150/1583]	Loss_D: 0.8705	Loss_G: 3.2710	D(x): 0.7855	D(G(z)): 0.4022 / 0.0535
[3/5][200/1583]	Loss_D: 0.5676	Loss_G: 2.3454	D(x): 0.7827	D(G(z)): 0.2484 / 0.1190
[3/5][250/1583]	Loss_D: 0.6734	Loss_G: 3.4769	D(x): 0.9145	D(G(z)): 0.4098 / 0.0381
[3/5][300/1583]	Loss_D: 0.6110	Loss_G: 1.5790	D(x): 0.6994	D(G(z)): 0.1888 / 0.2406
[3/5][350/1583]	Loss_D: 0.5078	Loss_G: 3.0860	D(x): 0.8976	D(G(z)): 0.3084 / 0.0566
[3/5][400/1583]	Loss_D: 0.8418	Loss_G: 4.8315	D(x): 0.9481	D(G(z)): 0.5053 / 0.0117
[3/5][450/1583]	Loss_D: 0.6444	Loss_G: 1.6765	D(x): 0.6694	D(G(z)): 0.1634 / 0.2183
[3/5][500/1583]	Loss_D: 1.5434	Loss_G: 4.1276	D(x): 0.9160	D(G(z)): 0.7060 / 0.0247
[3/5][550/1583]	Loss_D: 0.5104	Loss_G: 2.0650	D(x): 0.7646	D(G(z)): 0.1804 / 0.1609
[3/5][600/1583]	Loss_D: 1.3608	Loss_G: 4.7701	D(x): 0.9268	D(G(z)): 0.6700 / 0.0149
[3/5][650/1583]	Loss_D: 0.6086	Loss_G: 3.1287	D(x): 0.8555	D(G(z)): 0.3371 / 0.0551
[3/5][700/1583]	Loss_D: 0.7530	Loss_G: 3.6985	D(x): 0.7926	D(G(z)): 0.3518 / 0.0382
[3/5][750/1583]	Loss_D: 0.6241	Loss_G: 2.0707	D(x): 0.6915	D(G(z)): 0.1831 / 0.1598
[3/5][800/1583]	Loss_D: 0.6427	Loss_G: 3.5578	D(x): 0.9044	D(G(z)): 0.3754 / 0.0378
[3/5][850/1583]	Loss_D: 0.5044	Loss_G: 2.7291	D(x): 0.8211	D(G(z)): 0.2315 / 0.0878
[3/5][900/1583]	Loss_D: 0.6086	Loss_G: 3.5135	D(x): 0.8583	D(G(z)): 0.3297 / 0.0403
[3/5][950/1583]	Loss_D: 0.6418	Loss_G: 2.5607	D(x): 0.7214	D(G(z)): 0.2294 / 0.0991
[3/5][1000/1583]	Loss_D: 0.5335	Loss_G: 1.9207	D(x): 0.6979	D(G(z)): 0.1308 / 0.1890
[3/5][1050/1583]	Loss_D: 1.1061	Loss_G: 4.7492	D(x): 0.9441	D(G(z)): 0.6043 / 0.0137
[3/5][1100/1583]	Loss_D: 0.6131	Loss_G: 2.0501	D(x): 0.7609	D(G(z)): 0.2576 / 0.1575
[3/5][1150/1583]	Loss_D: 1.2984	Loss_G: 0.5656	D(x): 0.3524	D(G(z)): 0.0374 / 0.6265
[3/5][1200/1583]	Loss_D: 0.7278	Loss_G: 1.5974	D(x): 0.5992	D(G(z)): 0.1082 / 0.2450
[3/5][1250/1583]	Loss_D: 1.3815	Loss_G: 5.0266	D(x): 0.9395	D(G(z)): 0.6533 / 0.0109
[3/5][1300/1583]	Loss_D: 0.7119	Loss_G: 2.3102	D(x): 0.7915	D(G(z)): 0.3354 / 0.1222
[3/5][1350/1583]	Loss_D: 1.2739	Loss_G: 0.2641	D(x): 0.3607	D(G(z)): 0.0424 / 0.7833
[3/5][1400/1583]	Loss_D: 0.5943	Loss_G: 2.2916	D(x): 0.7708	D(G(z)): 0.2481 / 0.1273
[3/5][1450/1583]	Loss_D: 1.0932	Loss_G: 4.1132	D(x): 0.9061	D(G(z)): 0.5735 / 0.0278
[3/5][1500/1583]	Loss_D: 0.8520	Loss_G: 1.1129	D(x): 0.5194	D(G(z)): 0.1041 / 0.3780
[3/5][1550/1583]	Loss_D: 2.0437	Loss_G: 0.6944	D(x): 0.1913	D(G(z)): 0.0272 / 0.5869
[4/5][0/1583]	Loss_D: 0.6415	Loss_G: 2.2360	D(x): 0.7471	D(G(z)): 0.2587 / 0.1294
[4/5][50/1583]	Loss_D: 1.0841	Loss_G: 4.1800	D(x): 0.9679	D(G(z)): 0.5902 / 0.0243
[4/5][100/1583]	Loss_D: 0.7979	Loss_G: 1.3672	D(x): 0.5561	D(G(z)): 0.1175 / 0.3164
[4/5][150/1583]	Loss_D: 0.6017	Loss_G: 2.7566	D(x): 0.8011	D(G(z)): 0.2738 / 0.0822
[4/5][200/1583]	Loss_D: 0.7343	Loss_G: 3.3385	D(x): 0.8665	D(G(z)): 0.4031 / 0.0469
[4/5][250/1583]	Loss_D: 0.6154	Loss_G: 1.9663	D(x): 0.6328	D(G(z)): 0.0824 / 0.1788
[4/5][300/1583]	Loss_D: 0.9155	Loss_G: 2.7143	D(x): 0.7619	D(G(z)): 0.4138 / 0.0904
[4/5][350/1583]	Loss_D: 0.6844	Loss_G: 3.2434	D(x): 0.8772	D(G(z)): 0.3938 / 0.0474
[4/5][400/1583]	Loss_D: 0.6488	Loss_G: 3.6047	D(x): 0.8734	D(G(z)): 0.3633 / 0.0386
[4/5][450/1583]	Loss_D: 0.6037	Loss_G: 1.4727	D(x): 0.6563	D(G(z)): 0.1280 / 0.2684
[4/5][500/1583]	Loss_D: 0.5439	Loss_G: 2.5073	D(x): 0.7956	D(G(z)): 0.2403 / 0.0991
[4/5][550/1583]	Loss_D: 0.7262	Loss_G: 1.2411	D(x): 0.6007	D(G(z)): 0.1386 / 0.3340
[4/5][600/1583]	Loss_D: 0.8810	Loss_G: 4.6097	D(x): 0.9500	D(G(z)): 0.5232 / 0.0143
[4/5][650/1583]	Loss_D: 1.0356	Loss_G: 4.9807	D(x): 0.9327	D(G(z)): 0.5733 / 0.0137
[4/5][700/1583]	Loss_D: 0.5066	Loss_G: 2.3107	D(x): 0.7780	D(G(z)): 0.1920 / 0.1286
[4/5][750/1583]	Loss_D: 0.7037	Loss_G: 2.1193	D(x): 0.6649	D(G(z)): 0.1922 / 0.1494
[4/5][800/1583]	Loss_D: 0.4657	Loss_G: 2.3012	D(x): 0.7966	D(G(z)): 0.1882 / 0.1241
[4/5][850/1583]	Loss_D: 1.4215	Loss_G: 3.6018	D(x): 0.9601	D(G(z)): 0.6846 / 0.0428
[4/5][900/1583]	Loss_D: 0.7635	Loss_G: 3.0162	D(x): 0.8842	D(G(z)): 0.4314 / 0.0674
[4/5][950/1583]	Loss_D: 0.6663	Loss_G: 1.8579	D(x): 0.6075	D(G(z)): 0.0783 / 0.1948
[4/5][1000/1583]	Loss_D: 0.5781	Loss_G: 3.0544	D(x): 0.8583	D(G(z)): 0.3195 / 0.0565
[4/5][1050/1583]	Loss_D: 0.9002	Loss_G: 3.8505	D(x): 0.8861	D(G(z)): 0.4773 / 0.0313
[4/5][1100/1583]	Loss_D: 1.0272	Loss_G: 3.9188	D(x): 0.9282	D(G(z)): 0.5596 / 0.0341
[4/5][1150/1583]	Loss_D: 1.7347	Loss_G: 4.8050	D(x): 0.9351	D(G(z)): 0.7426 / 0.0117
[4/5][1200/1583]	Loss_D: 0.4433	Loss_G: 2.4808	D(x): 0.8453	D(G(z)): 0.2220 / 0.1046
[4/5][1250/1583]	Loss_D: 0.7117	Loss_G: 3.3558	D(x): 0.8295	D(G(z)): 0.3617 / 0.0496
[4/5][1300/1583]	Loss_D: 0.8814	Loss_G: 1.9139	D(x): 0.5378	D(G(z)): 0.1337 / 0.1938
[4/5][1350/1583]	Loss_D: 0.8146	Loss_G: 1.1267	D(x): 0.5384	D(G(z)): 0.1093 / 0.3813
[4/5][1400/1583]	Loss_D: 0.6365	Loss_G: 1.6231	D(x): 0.7090	D(G(z)): 0.2112 / 0.2331
[4/5][1450/1583]	Loss_D: 0.7084	Loss_G: 2.7831	D(x): 0.9043	D(G(z)): 0.4124 / 0.0820
[4/5][1500/1583]	Loss_D: 0.4315	Loss_G: 2.3095	D(x): 0.7849	D(G(z)): 0.1526 / 0.1271
[4/5][1550/1583]	Loss_D: 1.1658	Loss_G: 3.3444	D(x): 0.8109	D(G(z)): 0.5244 / 0.0604
In [18]:
# Plot both loss curves on one set of axes
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title("Generator and Discriminator Loss During Training")
ax.plot(G_losses, label="G")
ax.plot(D_losses, label="D")
ax.set_xlabel("iterations")
ax.set_ylabel("Loss")
ax.legend()
plt.show()
In [19]:
#%%capture
# Animate the saved fixed_noise grids to show the generator's progression
fig = plt.figure(figsize=(8, 8))
plt.axis("off")
frames = [[plt.imshow(np.transpose(frame, (1, 2, 0)), animated=True)] for frame in img_list]
ani = animation.ArtistAnimation(fig, frames, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())
Out[19]:
In [20]:
# Grab a fresh batch of real images from the dataloader for comparison
real_batch = next(iter(dataloader))

plt.figure(figsize=(15, 15))

# Left panel: real training images
plt.subplot(1, 2, 1)
plt.axis("off")
plt.title("Real Images")
real_grid = vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True)
plt.imshow(np.transpose(real_grid.cpu(), (1, 2, 0)))

# Right panel: generator samples saved at the final iteration
plt.subplot(1, 2, 2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1], (1, 2, 0)))
plt.show()
In [ ]: